#!/usr/bin/env python

import os
import json
import logging
import sys
import traceback
import pandas as pd
import numpy as np
import random
import datetime
from pathlib import Path
import torch
import shutil
import tarfile
from collections import OrderedDict
from transformers import AutoTokenizer

from fast_bert.data_cls import BertDataBunch
from fast_bert.learner_cls import BertLearner
from fast_bert.metrics import (
    accuracy,
    accuracy_multilabel,
    accuracy_thresh,
    fbeta,
    roc_auc,
)

# Wall-clock time at import, rendered as YYYY-MM-DD_HH-MM-SS.
run_start_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Root logger configuration: INFO and above, written to stdout only.
# File logging is currently disabled; a FileHandler used to be listed next to
# the stream handler below.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

# The root logger is shared by everything in this script.
logger = logging.getLogger()

# This algorithm has a single channel of input data called 'training'.
channel_name = "training"

# Base directory of the container's ML file layout.
prefix = "/opt/ml/"

input_path = prefix + "input/data"  # /opt/ml/input/data
output_path = os.path.join(prefix, "model")  # /opt/ml/model

# Since we run in File mode, the channel's input files are copied to this
# directory.
training_path = os.path.join(input_path, channel_name)  # /opt/ml/input/data/training

# /opt/ml/input/data/training/finetuned
finetuned_path = input_path + "/{}/finetuned".format(channel_name)

# /opt/ml/input/data/training/config
training_config_path = os.path.join(input_path, "{}/config".format(channel_name))

# /opt/ml/input/data/training/config/training_config.json
config_path = os.path.join(training_config_path, "training_config.json")

# /opt/ml/input/config/hyperparameters.json
hyperparam_path = os.path.join(prefix, "input/config/hyperparameters.json")


# Job entry point. Despite its name, train() runs batch inference with an
# already fine-tuned model; it does not train anything.
def train():
    """Run batch inference with a fine-tuned model and save the predictions.

    Despite the name, nothing is trained here. The function:

    1. Extracts ``model.tar.gz`` from the training channel directory into its
       ``model`` sub-directory.
    2. Loads ``training_config.json`` and ``hyperparameters.json``.
    3. Builds the tokenizer, databunch and learner from ``model/model_out``.
    4. Reads the first column of ``data.csv`` (dropping a ``text`` header row
       and empty cells), predicts on it, and writes ``out.csv`` under
       ``output_path``.

    Any exception raised along the way is logged, written together with its
    traceback to a ``failure`` file under ``output_path``, and the process
    exits with status 255.
    """
    logger.info("Starting batch inference...")

    DATA_PATH = Path(training_path)
    MODEL_PATH = DATA_PATH / "model"
    ARTIFACTS_PATH = MODEL_PATH / "model_out"

    try:
        # Extract inside the try block so that a missing or corrupt archive is
        # reported through the same failure path as every other error. The
        # context manager closes the archive; no explicit close() is needed.
        with tarfile.open(DATA_PATH / "model.tar.gz", "r:gz") as tar:
            tar.extractall(MODEL_PATH)

        with open(config_path, "r") as f:
            training_config = json.load(f)
            logger.info(training_config)

        with open(hyperparam_path, "r") as tc:
            hyperparameters = json.load(tc)
            logger.info(hyperparameters)

        # convert string bools to booleans
        training_config["multi_label"] = training_config["multi_label"] == "True"
        training_config["fp16"] = training_config["fp16"] == "True"
        training_config["text_col"] = training_config.get("text_col", "text")
        training_config["label_col"] = training_config.get("label_col", "label")
        training_config["train_file"] = training_config.get("train_file", "train.csv")
        training_config["val_file"] = training_config.get("val_file", "val.csv")
        training_config["label_file"] = training_config.get("label_file", "labels.csv")
        training_config["random_state"] = training_config.get("random_state", None)
        training_config["labels_count"] = int(training_config.get("labels_count", 10))
        if training_config["random_state"] is not None:
            logger.info(
                "setting random state {}".format(training_config["random_state"])
            )
            random_seed(int(training_config["random_state"]))

        # use auto-tokenizer
        tokenizer = AutoTokenizer.from_pretrained(str(ARTIFACTS_PATH), use_fast=True)

        # Prefer the GPU, but fall back to CPU instead of failing on hosts
        # without CUDA.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        multi_gpu = torch.cuda.device_count() > 1

        logger.info("Number of GPUs: {}".format(torch.cuda.device_count()))

        # Create databunch. No train/val files are supplied: only ad-hoc texts
        # are predicted on below.
        databunch = BertDataBunch(
            MODEL_PATH,
            MODEL_PATH,
            tokenizer,
            train_file=None,
            val_file=None,
            batch_size_per_gpu=int(hyperparameters["train_batch_size"]),
            max_seq_length=int(hyperparameters["max_seq_length"]),
            multi_gpu=multi_gpu,
            multi_label=training_config["multi_label"],
            model_type=training_config["model_type"],
            logger=logger,
            no_cache=True,
        )

        # Initialise the learner
        learner = BertLearner.from_pretrained_model(
            databunch,
            str(ARTIFACTS_PATH),
            metrics=[],
            device=device,
            logger=logger,
            output_dir=None,
            is_fp16=False,
            multi_gpu=multi_gpu,
            multi_label=training_config["multi_label"],
            logging_steps=0,
        )

        # Only the first column of data.csv is used as input text.
        df = pd.read_csv(str(DATA_PATH / "data.csv"), header=None)
        df = df.iloc[:, 0:1]
        # if first row is header, remove it
        if df.iloc[0, 0] == "text":
            df = df.iloc[1:]
        df.columns = ["text"]
        df.dropna(subset=["text"], inplace=True)

        texts = list(df["text"].values)

        predictions = learner.predict_batch(texts)

        processed_predictions = process_batch_results(
            texts, results=predictions, labels_count=training_config["labels_count"]
        )

        # save test results with model outcome
        os.makedirs(output_path, exist_ok=True)
        pd.DataFrame(processed_predictions).to_csv(
            os.path.join(output_path, "out.csv"), index=None
        )

    except Exception as e:
        trc = traceback.format_exc()
        message = "Exception during batch inference: " + str(e) + "\n" + trc

        # Log first so the error reaches the job logs even if writing the
        # failure file below goes wrong. Logger.error() takes no `file`
        # argument; the destination is decided by the configured handlers.
        logger.error(message)

        # Write out an error file next to the other outputs.
        # NOTE(review): SageMaker documents /opt/ml/output/failure as the file
        # surfaced as the job's failure reason, while output_path points at the
        # model directory -- confirm this location is intended.
        os.makedirs(output_path, exist_ok=True)
        with open(os.path.join(output_path, "failure"), "w") as s:
            s.write(message)

        # A non-zero exit code causes the training job to be marked as Failed.
        sys.exit(255)


def process_batch_results(texts, results, labels_count=None):
    """Flatten per-text predictions into one ordered record per text.

    Args:
        texts: Input texts, indexable by the position of the matching entry in
            ``results``.
        results: One sequence per text; each item is indexable, with the label
            at index 0 and its confidence at index 1.
        labels_count: Maximum number of items kept per text. A falsy value
            (``None`` or ``0``) keeps all of them.

    Returns:
        A list of ``OrderedDict`` records with a ``text`` key followed by
        ``label_<k>`` / ``confidence_<k>`` pairs, ``k`` starting at 1.
    """
    rows = []
    for position, ranked in enumerate(results):
        row = OrderedDict()
        row["text"] = texts[position]

        if labels_count:
            ranked = ranked[:labels_count]

        for rank, entry in enumerate(ranked, start=1):
            row["label_{}".format(rank)] = entry[0]
            row["confidence_{}".format(rank)] = entry[1]

        rows.append(row)

    return rows


def random_seed(seed_value):
    """Seed Python, NumPy and PyTorch random number generators.

    When CUDA is available the CUDA generators are seeded as well, and cuDNN
    is set to deterministic mode with benchmarking turned off.

    Args:
        seed_value: Seed passed unchanged to every generator.
    """
    # CPU-side generators: Python stdlib, NumPy, then PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)

    if not torch.cuda.is_available():
        return

    # GPU-side generators: current device first, then all devices.
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


if __name__ == "__main__":
    # train() calls sys.exit(255) from its own error handler, so control only
    # gets past this call when no exception was handled there.
    train()

    # A zero exit code causes the job to be marked as Succeeded.
    raise SystemExit(0)
